import java.util.HashMap;
import java.util.Map;

/** Finds the longest contiguous subarray with an equal number of zeros and ones. */
public class FindMaxLength {

    /**
     * Returns the length of the longest contiguous subarray of {@code nums} that contains an
     * equal number of 0s and 1s.
     *
     * <p>Each 1 adds +1 and each 0 adds -1 to a running balance. Two positions with the same
     * balance enclose a subarray whose zeros and ones cancel out. So the answer is the largest
     * distance between the first occurrence of a balance and any later occurrence of it.
     * Runs in O(n) time with O(n) extra space.
     *
     * @param nums the input array; elements are expected to be 0 or 1 (any non-zero value is
     *     counted as a 1)
     * @return the length of the longest balanced subarray, or 0 if none exists (including when
     *     {@code nums} is empty)
     */
    public int findMaxLength(int[] nums) {
        // balance -> first index i at which the balance of nums[0..i] reached that value.
        // Balance 0 holds before any element is read, so it is seeded with index -1.
        Map<Integer, Integer> firstSeen = new HashMap<>();
        firstSeen.put(0, -1);
        int best = 0;
        int balance = 0;
        for (int i = 0; i < nums.length; i++) {
            balance += (nums[i] == 0) ? -1 : 1;
            // Records i only if this balance is new; otherwise returns the earliest index.
            Integer earlier = firstSeen.putIfAbsent(balance, i);
            if (earlier != null) {
                // nums[earlier + 1 .. i] has equal numbers of zeros and ones.
                best = Math.max(best, i - earlier);
            }
        }
        return best;
    }
}
